# Spectral GCN + Attention Recovery + LSTM
# This code trains and tests the GNN model for COVID-19 infection prediction in Tokyo
# Author: Jiawei Xue, August 26, 2021
# Step 1: read and pack the training and testing data
# Step 2: training epoch, training process, testing
# Step 3: build the model = spectral GCN + Attention Recovery + LSTM
# Step 4: main function
# Step 5: evaluation
# Step 6: visualization
import os
import csv
import json
import copy
import time
import random
import string
import argparse
import numpy as np
import pandas as pd
import geopandas as gpd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
from Memory_spectral_T3_GCN_memory_light_1_layer import SpecGCN
from Memory_spectral_T3_GCN_memory_light_1_layer import SpecGCN_LSTM
#torch.set_printoptions(precision=8)
#hyperparameters for the data setting
X_day, Y_day = 21,7
#START_DATE, END_DATE = '20200414','20210207'
#START_DATE, END_DATE = '20200808','20210603'
START_DATE, END_DATE = '20200720','20210515'
WINDOW_SIZE = 7
#hyperparameters for the learning
DROPOUT, ALPHA = 0.50, 0.20
NUM_EPOCHS, BATCH_SIZE, LEARNING_RATE = 100, 8, 0.0001
HIDDEN_DIM_1, OUT_DIM_1, HIDDEN_DIM_2 = 6,4,3
infection_normalize_ratio = 100.0
web_search_normalize_ratio = 100.0
train_ratio = 0.7
validate_ratio = 0.1
SEED = 5
torch.manual_seed(SEED)
r = random.random
random.seed(5)
#1.total period (mobility+text):
#from 20200201 to 20210620: (29+31+30+31+30+31+31+30+31+30+31)+(31+28+31+30+31+20)\
#= 335 + 171 = 506;
#2.number of zones: 23;
#3.infection period:
#20200331 to 20210620: (1+30+31+30+31+31+30+31+30+31)+(31+28+31+30+31+20) = 276 + 171 = 447.
#1. Mobility: functions 1.2 to 1.7
#2. Text: functions 1.8 to 1.14
#3. Infection: function 1.15
#4. Preprocess: functions 1.16 to 1.24
#5. Learn: functions 1.25 to 1.26
#function 1.1
#get the central area of Tokyo (i.e., the 23 special wards)
#return: a 23-zone shapefile
def read_tokyo_23():
folder = "/data/HSEES/xue/xue_codes/disease_prediction_ml/gml_code/present_model_version10/tokyo_23"
file = "tokyo_23zones.shp"
path = os.path.join(folder,file)
data = gpd.read_file(path)
return data
##################1.Mobility#####################
#function 1.2
#get the average of two days' mobility (infection) records
def mob_inf_average(data, key1, key2):
new_record = dict()
record1, record2 = data[key1], data[key2]
for i in record1:
if i in record2:
new_record[i] = (record1[i]+record2[i])/2.0
return new_record
#function 1.3
#get the average of multiple days' mobility (infection) records
def mob_inf_average_multiple(data, keyList):
new_record = dict()
num_day = len(keyList)
for i in range(num_day):
record = data[keyList[i]]
for zone_id in record:
if zone_id not in list(new_record.keys()):
new_record[zone_id] = record[zone_id]
else:
new_record[zone_id] += record[zone_id]
for new_record_key in new_record:
new_record[new_record_key] = new_record[new_record_key]*1.0/num_day
return new_record
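#illustrative example (not called anywhere in this script): a minimal sketch, using
#two hypothetical days of toy records, of how mob_inf_average_multiple averages them.
def _example_mob_inf_average_multiple():
    toy = {"20200401": {"13101": 2.0, "13102": 4.0},
           "20200402": {"13101": 4.0, "13102": 6.0}}
    avg = mob_inf_average_multiple(toy, ["20200401", "20200402"])
    assert avg == {"13101": 3.0, "13102": 5.0}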
#function 1.4
#generate the dateList: [20200101, 20200102, ..., 20211231]
def generate_dateList():
yearList = ["2020","2021"]
monthList = ["0"+str(i+1) for i in range(9)] + ["10","11","12"]
dayList = ["0"+str(i+1) for i in range(9)] + [str(i) for i in range(10,32)]
day_2020_num = [31,29,31,30,31,30,31,31,30,31,30,31]
day_2021_num = [31,28,31,30,31,30,31,31,30,31,30,31]
date_2020, date_2021 = list(), list()
for i in range(12):
for j in range(day_2020_num[i]):
date_2020.append(yearList[0] + monthList[i] + dayList[j])
for j in range(day_2021_num[i]):
date_2021.append(yearList[1] + monthList[i] + dayList[j])
date_2020_2021 = date_2020 + date_2021
return date_2020_2021
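#illustrative sanity check (not called anywhere in this script): 2020 has 366 days
#and 2021 has 365, so generate_dateList should return 731 dates in order.
def _example_generate_dateList():
    dates = generate_dateList()
    assert len(dates) == 366 + 365
    assert dates[0] == "20200101" and dates[-1] == "20211231"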
#function 1.5
#smooth the mobility (infection) data using the neighborhood average
#under a given window size
#dateList: [20200101, 20200102, ..., 20211231]
def mob_inf_smooth(data, window_size, dateList):
data_copy = copy.copy(data)
data_key_list = list(data_copy.keys())
for data_key in data_key_list:
left = int(max(dateList.index(data_key)-(window_size-1)/2, 0))
right = int(min(dateList.index(data_key)+(window_size-1)/2, len(dateList)-1))
potential_neighbor = dateList[left:right+1]
neighbor_data_key = list(set(data_key_list).intersection(set(potential_neighbor)))
data_average = mob_inf_average_multiple(data_copy, neighbor_data_key)
data[data_key] = data_average
return data
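#illustrative example (not called anywhere in this script): a minimal sketch of
#mob_inf_smooth with a 3-day window; "20200102" is replaced by the average of
#itself and its two neighboring days (toy values, hypothetical zone id).
def _example_mob_inf_smooth():
    dates = generate_dateList()
    toy = {"20200101": {"13101": 1.0},
           "20200102": {"13101": 2.0},
           "20200103": {"13101": 6.0}}
    smoothed = mob_inf_smooth(toy, 3, dates)
    assert smoothed["20200102"]["13101"] == (1.0 + 2.0 + 6.0) / 3.0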
#function 1.6
#set the mobility (infection) record of one day to zero (zeros for every zone shared by key1 and key2)
def mob_inf_average_null(data, key1, key2):
new_record = dict()
record1, record2 = data[key1], data[key2]
for i in record1:
if i in record2:
new_record[i] = 0
return new_record
#function 1.7
#read the mobility data from "mobility_feature_20200201.json"...
#return: all_mobility:{"20200201":{('123','123'):12345,...},...}
#20200201 to 20210620: 506 days
def read_mobility_data(jcode23):
all_mobility = dict()
mobilityFilePath = "/data/HSEES/xue/xue_codes/disease_prediction_ml/gml_code/"+\
"present_model_version10/mobility_20210804"
mobilityNameList = os.listdir(mobilityFilePath)
for i in range(len(mobilityNameList)):
day_mobility = dict()
file_name = mobilityNameList[i]
if "20" in file_name:
day = (file_name.split("_")[2]).split(".")[0] #get the day
file_path = mobilityFilePath + '/' + file_name
f = open(file_path,)
df_file = json.load(f) #read the mobility file
f.close()
for key in df_file:
origin, dest = key.split("_")[0], key.split("_")[1]
if origin in jcode23 and dest in jcode23:
if origin == dest:
day_mobility[(origin, dest)] = 0.0 #ignore the intra-zone flow
else:
day_mobility[(origin, dest)] = df_file[key]
all_mobility[day] = day_mobility
#missing data
all_mobility["20201128"] = mob_inf_average(all_mobility,"20201127","20201129")
all_mobility["20210104"] = mob_inf_average(all_mobility, "20210103","20210105")
return all_mobility
##################2.Text#####################
#function 1.8
#get the average of two days' text records
def text_average(data, key1, key2):
new_record = dict()
record1, record2 = data[key1], data[key2]
for i in record1:
if i in record2:
zone_record1, zone_record2 = record1[i], record2[i]
new_zone_record = dict()
for j in zone_record1:
if j in zone_record2:
new_zone_record[j] = (zone_record1[j] + zone_record2[j])/2.0
new_record[i] = new_zone_record
return new_record
#function 1.9
#get the average of multiple days' text records
def text_average_multiple(data, keyList):
new_record = dict()
num_day = len(keyList)
for i in range(num_day):
record = data[keyList[i]]
for zone_id in record: #zone_id
if zone_id not in new_record:
new_record[zone_id] = dict()
for j in record[zone_id]: #symptom
if j not in new_record[zone_id]:
new_record[zone_id][j] = record[zone_id][j]
else:
new_record[zone_id][j] += record[zone_id][j]
for zone_id in new_record:
for j in new_record[zone_id]:
new_record[zone_id][j] = new_record[zone_id][j]*1.0/num_day
return new_record
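#illustrative example (not called anywhere in this script): text_average_multiple
#averages the nested {zone: {symptom: count}} records over days; the zone id and
#the English symptom names below are hypothetical placeholders.
def _example_text_average_multiple():
    toy = {"20200401": {"13101": {"fever": 2.0, "cough": 4.0}},
           "20200402": {"13101": {"fever": 4.0, "cough": 8.0}}}
    avg = text_average_multiple(toy, ["20200401", "20200402"])
    assert avg == {"13101": {"fever": 3.0, "cough": 6.0}}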
#function 1.10
#smooth the text data using the neighborhood average
#under a given window size
def text_smooth(data, window_size, dateList):
data_copy = copy.copy(data)
data_key_list = list(data_copy.keys())
for data_key in data_key_list:
left = int(max(dateList.index(data_key)-(window_size-1)/2, 0))
right = int(min(dateList.index(data_key)+(window_size-1)/2, len(dateList)-1))
potential_neighbor = dateList[left:right+1]
neighbor_data_key = list(set(data_key_list).intersection(set(potential_neighbor)))
data_average = text_average_multiple(data_copy, neighbor_data_key)
data[data_key] = data_average
return data
#function 1.11
#read the number of user points
def read_point_json():
with open('user_point/mobility_user_point.json') as point1:
user_point1 = json.load(point1)
with open('user_point/mobility_user_point_20210812.json') as point2:
user_point2 = json.load(point2)
user_point_all = dict()
for i in user_point1:
user_point_all[i] = user_point1[i]
for i in user_point2:
user_point_all[i] = user_point2[i]
user_point_all["20201128"] = user_point_all["20201127"] #data missing
user_point_all["20210104"] = user_point_all["20210103"] #data missing
return user_point_all
#function 1.12
#normalize the text search counts by the daily number of users.
def normalize_text_user(all_text, user_point_all):
for day in all_text:
if day in user_point_all:
num_user = user_point_all[day]["num_user"]
all_text_day_new = dict()
all_text_day = all_text[day]
for zone in all_text_day:
if zone not in all_text_day_new:
all_text_day_new[zone] = dict()
for sym in all_text_day[zone]:
all_text_day_new[zone][sym] = all_text_day[zone][sym]*1.0/num_user
all_text[day] = all_text_day_new
return all_text
#function 1.13
#read the text data
#20200201 to 20210620: 506 days
#all_text = {"20200211":{"123":{"code":3,"fever":2,...},...},...}
def read_text_data(jcode23):
all_text = dict()
textFilePath = "/data/HSEES/xue/xue_codes/disease_prediction_ml/gml_code/"+\
"present_model_version10/text_20210804"
textNameList = os.listdir(textFilePath)
for i in range(len(textNameList)):
day_text = dict()
file_name = textNameList[i]
if "20" in file_name:
day = (file_name.split("_")[2]).split(".")[0]
file_path = textFilePath + "/" + file_name
f = open(file_path,)
df_file = json.load(f) #read the text file
f.close()
new_dict = dict()
for key in df_file:
if key in jcode23:
new_dict[key] = {key1:df_file[key][key1]*1.0*web_search_normalize_ratio for key1 in df_file[key]}
#new_dict[key] = df_file[key]*WEB_SEARCH_RATIO
all_text[day] = new_dict
all_text["20201030"] = text_average(all_text, "20201029", "20201031") #data missing
return all_text
#function 1.14
#perform the min-max normalization for the text data.
def min_max_text_data(all_text,jcode23):
#calculate the min_max
#region_key: sym: [min,max]
text_list = list(['痛み', '頭痛', '咳', '下痢', 'ストレス', '不安', \
'腹痛', 'めまい', '吐き気', '嘔吐', '筋肉痛', '動悸', \
'副鼻腔炎', '発疹', 'くしゃみ', '倦怠感', '寒気', '脱水', \
'中咽頭', '関節痛', '不眠症', '睡眠障害', '鼻漏', '片頭痛', \
'多汗症', 'ほてり', '胸痛', '発汗', '無気力', '呼吸困難', \
'喘鳴', '目の痛み', '体の痛み', '無嗅覚症', '耳の痛み', \
'錯乱', '見当識障害', '胸の圧迫感', '鼻の乾燥', '耳感染症', \
'味覚消失', '上気道感染症', '眼感染症', '食欲減少'])
region_sym_min_max = dict()
for key in jcode23: #initialize
region_sym_min_max[key] = dict()
for sym in text_list:
region_sym_min_max[key][sym] = [1000000,0] #min, max
for day in all_text: #update
for key in jcode23:
for sym in text_list:
if sym in all_text[day][key]:
count = all_text[day][key][sym]
if count < region_sym_min_max[key][sym][0]:
region_sym_min_max[key][sym][0] = count
if count > region_sym_min_max[key][sym][1]:
region_sym_min_max[key][sym][1] = count
#print ("region_sym_min_max",region_sym_min_max)
for key in jcode23: #normalize
for sym in text_list:
min_count,max_count=region_sym_min_max[key][sym][0],region_sym_min_max[key][sym][1]
for day in all_text:
if sym in all_text[day][key]:
if max_count-min_count == 0:
all_text[day][key][sym] = 1
else:
all_text[day][key][sym] = (all_text[day][key][sym]-min_count)*1.0/(max_count-min_count)
#print("all_text[day][key][sym]",all_text[day][key][sym])
return all_text
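#illustrative example (not called anywhere in this script): min_max_text_data rescales
#each (zone, symptom) series to [0, 1] via (x - min) / (max - min); the single-zone
#toy data below is hypothetical.
def _example_min_max_text_data():
    toy = {"20200401": {"13101": {"咳": 2.0}},
           "20200402": {"13101": {"咳": 4.0}},
           "20200403": {"13101": {"咳": 6.0}}}
    scaled = min_max_text_data(toy, ["13101"])
    assert scaled["20200402"]["13101"]["咳"] == (4.0 - 2.0) / (6.0 - 2.0)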
##################3.Infection#####################
#function 1.15
#read the infection data
#20200331 to 20210620: (1+30+31+30+31+31+30+31+30+31)+(31+28+31+30+31+20) = 276 + 171 = 447.
#all_infection = {"20200201":{"123":1,"123":2}}
def read_infection_data(jcode23):
all_infection = dict()
infection_path = "/data/HSEES/xue/xue_codes/disease_prediction_ml/gml_code/"+\
"present_model_version10/patient_20210725.json"
f = open(infection_path,)
df_file = json.load(f) #read the infection file
f.close()
for zone_id in df_file:
for one_day in df_file[zone_id]:
daySplit = one_day.split("/")
year, month, day = daySplit[0], daySplit[1], daySplit[2]
if len(month) == 1:
month = "0" + month
if len(day) == 1:
day = "0" + day
new_date = year + month + day
if str(zone_id[0:5]) in jcode23:
if new_date not in all_infection:
all_infection[new_date] = {zone_id[0:5]:df_file[zone_id][one_day]*1.0/infection_normalize_ratio}
else:
all_infection[new_date][zone_id[0:5]] = df_file[zone_id][one_day]*1.0/infection_normalize_ratio
#missing
date_list = [str(20200316+i) for i in range(15)]
for date in date_list:
all_infection[date] = mob_inf_average(all_infection,'20200401','20200401')
all_infection['20200514'] = mob_inf_average(all_infection,'20200513','20200515')
all_infection['20200519'] = mob_inf_average(all_infection,'20200518','20200520')
all_infection['20200523'] = mob_inf_average(all_infection,'20200522','20200524')
all_infection['20200530'] = mob_inf_average(all_infection,'20200529','20200601')
all_infection['20200531'] = mob_inf_average(all_infection,'20200529','20200601')
all_infection['20201231'] = mob_inf_average(all_infection,'20201230','20210101')
all_infection['20210611'] = mob_inf_average(all_infection,'20210610','20210612')
#outlier
all_infection['20200331'] = mob_inf_average(all_infection,'20200401','20200401')
all_infection['20200910'] = mob_inf_average(all_infection,'20200909','20200912')
all_infection['20200911'] = mob_inf_average(all_infection,'20200909','20200912')
all_infection['20200511'] = mob_inf_average(all_infection,'20200510','20200512')
all_infection['20201208'] = mob_inf_average(all_infection,'20201207','20201209')
all_infection['20210208'] = mob_inf_average(all_infection,'20210207','20210209')
all_infection['20210214'] = mob_inf_average(all_infection,'20210213','20210215')
#calculate the day-over-day difference (daily increments)
all_infection_subtraction = dict()
all_infection_subtraction['20200331'] = all_infection['20200331']
all_keys = list(all_infection.keys())
all_keys.sort()
for i in range(len(all_keys)-1):
record = dict()
for j in all_infection[all_keys[i+1]]:
record[j] = all_infection[all_keys[i+1]][j] - all_infection[all_keys[i]][j]
all_infection_subtraction[all_keys[i+1]] = record
return all_infection_subtraction, all_infection
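#illustrative example (not called anywhere in this script): a minimal sketch of the
#subtraction step above, assuming the raw records are cumulative counts, so the
#difference of consecutive days gives the daily new cases (toy numbers).
def _example_infection_subtraction():
    cumulative = {"20200401": {"13101": 10.0}, "20200402": {"13101": 13.0}}
    keys = sorted(cumulative.keys())
    daily = {keys[0]: cumulative[keys[0]]}
    for i in range(len(keys) - 1):
        daily[keys[i+1]] = {z: cumulative[keys[i+1]][z] - cumulative[keys[i]][z]
                            for z in cumulative[keys[i+1]]}
    assert daily["20200402"]["13101"] == 3.0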
##################4.Preprocess#####################
#function 1.16
#ensemble the mobility, text, and infection.
#all_mobility = {"20200201":{('123','123'):12345,...},...}
#all_text = {"20200201":{"123":{"cold":3,"fever":2,...},...},...}
#all_infection = {"20200316":{"123":1,"123":2}}
#all_x_y = {"0":[[mobility_1,text_1, ..., mobility_x_day,text_x_day], [infection_1,...,infection_y_day],\
#[infection_1,...,infection_x_day]],0}
#x_days, y_days: use x_days to predict y_days
def ensemble(all_mobility, all_text, all_infection, x_days, y_days, all_day_list):
all_x_y = dict()
for j in range(len(all_day_list) - x_days - y_days + 1):
x_sample, y_sample, x_sample_infection = list(), list(), list()
#add the data from all_day_list[0+j] to all_day_list[x_days-1+j]
for k in range(x_days):
day = all_day_list[k + j]
x_sample.append(all_mobility[day])
x_sample.append(all_text[day])
x_sample_infection.append(all_infection[day]) #concatenate with the infection data
#add the data from all_day_list[x_days+j] to all_day_list[x_days+y_day-1+j]
for k in range(y_days):
day = all_day_list[x_days + k + j]
y_sample.append(all_infection[day])
all_x_y[str(j)] = [x_sample, y_sample, x_sample_infection,j]
return all_x_y
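#illustrative example (not called anywhere in this script): ensemble builds one sample
#per sliding window, so D dates with x_days inputs and y_days targets give
#D - x_days - y_days + 1 samples; each sample is
#[interleaved mobility/text, y_days infections, x_days infections, window index].
def _example_ensemble():
    days = ["20200401", "20200402", "20200403", "20200404"]
    mob = {d: {("13101", "13102"): 1.0} for d in days}
    text = {d: {"13101": {"咳": 1.0}} for d in days}
    inf = {d: {"13101": 1.0} for d in days}
    samples = ensemble(mob, text, inf, 2, 1, days)
    assert len(samples) == 4 - 2 - 1 + 1
    assert len(samples["0"][0]) == 2 * 2  #mobility and text matrices for each of the 2 input days
    assert len(samples["0"][1]) == 1      #one target infection day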
#function 1.17
#split the data by train/validate/test = train_ratio/validation_ratio/(1-train_ratio-validation_ratio)
def split_data(all_x_y, train_ratio, validation_ratio):
all_x_y_key = list(all_x_y.keys())
n = len(all_x_y_key)
n_train, n_validate = round(n*train_ratio), round(n*validation_ratio)
n_test = n-n_train-n_validate
train_key = [all_x_y[str(i)] for i in range(n_train)]
validate_key = [all_x_y[str(i+n_train)] for i in range(n_validate)]
test_key = [all_x_y[str(i+n_train+n_validate)] for i in range(n_test)]
return train_key, validate_key, test_key
##function 1.18
#the second data split method
#split the data by train/validate/test = train_ratio/validation_ratio/(1-train_ratio-validation_ratio)
def split_data_2(all_x_y, train_ratio, validation_ratio):
all_x_y_key = list(all_x_y.keys())
n = len(all_x_y_key)
n_train, n_validate = round(n*train_ratio), round(n*validation_ratio)
n_test = n-n_train-n_validate
train_list, validate_list = list(), list()
train_validate_key = [all_x_y[str(i)] for i in range(n_train+n_validate)]
train_key, validate_key = list(), list()
for i in range(len(train_validate_key)):
if i % 9 == 8:
validate_key.append(all_x_y[str(i)])
validate_list.append(i)
else:
train_key.append(all_x_y[str(i)])
train_list.append(i)
test_key = [all_x_y[str(i+n_train+n_validate)] for i in range(n_test)]
return train_key, validate_key, test_key, train_list, validate_list
##function 1.19
#the third data split method
#split the data by train/validate/test = train_ratio/validation_ratio/(1-train_ratio-validation_ratio)
def split_data_3(all_x_y, train_ratio, validation_ratio):
all_x_y_key = list(all_x_y.keys())
n = len(all_x_y_key)
n_train, n_validate = round(n*train_ratio), round(n*validation_ratio)
n_test = n - n_train - n_validate
train_list, validate_list = list(), list()
train_validate_key = [all_x_y[str(i)] for i in range(n_train + n_validate)]
train_key, validate_key = list(), list()
for i in range(len(train_validate_key)):
if (n_train + n_validate-i) % 2 == 0 and (n_train + n_validate-i) <= 2*n_validate:
validate_key.append(all_x_y[str(i)])
validate_list.append(i)
else:
train_key.append(all_x_y[str(i)])
train_list.append(i)
test_key = [all_x_y[str(i+n_train+n_validate)] for i in range(n_test)]
return train_key, validate_key, test_key, train_list, validate_list
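#illustrative example (not called anywhere in this script): split_data_3 keeps the last
#n_test windows for testing and takes every other index from the tail of the
#train+validate block as validation; with 10 toy samples and the ratios used in this
#script (0.7/0.1), index 6 becomes validation and the last 2 windows become test.
def _example_split_data_3():
    toy = {str(i): ["x", "y", "x_infection", i] for i in range(10)}
    train, validate, test, train_list, validate_list = split_data_3(toy, 0.7, 0.1)
    assert validate_list == [6]
    assert len(train) == 7 and len(validate) == 1 and len(test) == 2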
##function 1.20
#find the mobility dates starting from the day that is x_days before start_date
#e.g., start_date = "20200331", x_days = 7
def sort_date(all_mobility, start_date, x_days):
mobility_date_list = list(all_mobility.keys())
mobility_date_list.sort()
idx = mobility_date_list.index(start_date)
mobility_date_cut = mobility_date_list[idx-x_days:]
return mobility_date_cut
#function 1.21
#find the mobility dates starting from the day that is x_days before start_date
#and ending at the day that is y_days after end_date
#e.g., start_date = "20200331", x_days = 7
def sort_date_2(all_mobility, start_date, x_days, end_date, y_days):
mobility_date_list = list(all_mobility.keys())
mobility_date_list.sort()
idx = mobility_date_list.index(start_date)
idx2 = mobility_date_list.index(end_date)
mobility_date_cut = mobility_date_list[idx-x_days:idx2+y_days]
return mobility_date_cut
#function 1.22
#get the mappings from zone id to index and from symptom (text) to index
#return zone_dict and text_dict
def get_zone_text_to_idx(all_infection):
zone_list = list(set(all_infection["20200401"].keys()))
text_list = list(['痛み', '頭痛', '咳', '下痢', 'ストレス', '不安', \
'腹痛', 'めまい'])
zone_list.sort()
zone_dict = {str(zone_list[i]):i for i in range(len(zone_list))}
text_dict = {str(text_list[i]):i for i in range(len(text_list))}
return zone_dict, text_dict
#function 1.23
#change the data format to matrices
#zoneid_to_idx = {"13101":0, "13102":1, ..., "13123":22}
#sym_to_idx = {"cough":0, ...}
#mobility: {('13101', '13102'): 709973, ...}
#text: {'13101': {'痛み': 51,...},...}
#infection: {'13101': 50, '13102': 137, '13103': 401,...}
#data_type: one of "mobility", "text", "infection"
def to_matrix(zoneid_to_idx, sym_to_idx, input_data, data_type):
n_zone, n_text = len(zoneid_to_idx), len(sym_to_idx)
if data_type == "mobility":
result = np.zeros((n_zone, n_zone))
for key in input_data:
from_id, to_id = key[0], key[1]
from_idx, to_idx = zoneid_to_idx[from_id], zoneid_to_idx[to_id]
result[from_idx][to_idx] += input_data[key]
if data_type == "text":
result = np.zeros((n_zone, n_text))
for key1 in input_data:
for key2 in input_data[key1]:
if key1 in list(zoneid_to_idx.keys()) and key2 in list(sym_to_idx.keys()):
zone_idx, text_idx = zoneid_to_idx[key1], sym_to_idx[key2]
result[zone_idx][text_idx] += input_data[key1][key2]
if data_type == "infection":
result = np.zeros(n_zone)
for key in input_data:
zone_idx = zoneid_to_idx[key]
result[zone_idx] += input_data[key]
return result
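#illustrative example (not called anywhere in this script): to_matrix turns the raw
#dictionaries into dense numpy arrays; the two-zone index maps below are hypothetical.
def _example_to_matrix():
    zone_idx = {"13101": 0, "13102": 1}
    sym_idx = {"咳": 0}
    mob = to_matrix(zone_idx, sym_idx, {("13101", "13102"): 5.0}, "mobility")
    assert mob.shape == (2, 2) and mob[0][1] == 5.0
    inf = to_matrix(zone_idx, sym_idx, {"13101": 3.0, "13102": 1.0}, "infection")
    assert inf.shape == (2,) and inf[0] == 3.0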
#function 1.24
#change the data to the matrix format
def change_to_matrix(data, zoneid_to_idx, sym_to_idx):
data_result = list()
for i in range(len(data)):
combine1, combine2 = list(), list()
combine3 = list() #NEW
mobility_text = data[i][0]
x_infection_all = data[i][2] #the x_days infection data
day_order = data[i][3] #NEW the order of the day
for j in range(round(len(mobility_text)*1.0/2)):
mobility, text = mobility_text[2*j], mobility_text[2*j+1]
x_infection = x_infection_all[j] #NEW
new_mobility = to_matrix(zoneid_to_idx, sym_to_idx, mobility, "mobility")
new_text = to_matrix(zoneid_to_idx, sym_to_idx, text, "text")
combine1.append(new_mobility)
combine1.append(new_text)
new_x_infection = to_matrix(zoneid_to_idx, sym_to_idx, x_infection, "infection") #NEW
combine3.append(new_x_infection) #NEW
for j in range(len(data[i][1])):
infection = data[i][1][j]
new_infection = to_matrix(zoneid_to_idx, sym_to_idx, infection, "infection")
combine2.append(new_infection)
data_result.append([combine1,combine2,combine3,day_order]) #mobility/text; infection_y; infection_x; day_order
return data_result
##################5.learn#####################
#function 1.25
def visual_loss(e_losses, vali_loss, test_loss):
plt.figure(figsize=(4,3), dpi=300)
x = range(len(e_losses))
y1,y2,y3 = copy.copy(e_losses), copy.copy(vali_loss), copy.copy(test_loss)
plt.plot(x,y1,linewidth=1, label="train")
plt.plot(x,y2,linewidth=1, label="validate")
plt.plot(x,y3,linewidth=1, label="test")
plt.legend()
plt.title('Loss decline on entire training/validation/testing data')
plt.xlabel('Epoch')
plt.ylabel('Loss')
#plt.savefig('final_f6.png',bbox_inches = 'tight')
plt.show()
#function 1.26
def visual_loss_train(e_losses):
plt.figure(figsize=(4,3), dpi=300)
x = range(len(e_losses))
y1 = copy.copy(e_losses)
plt.plot(x,y1,linewidth=1, label="train")
plt.legend()
plt.title('Loss decline on entire training data')
plt.xlabel('Epoch')
plt.ylabel('Loss')
#plt.savefig('final_f6.png',bbox_inches = 'tight')
plt.show()
#function 2.1
#normalize each column of the input mobility matrix so that it sums to one
def normalize_column_one(input_matrix):
column_sum = np.sum(input_matrix, axis=0)
row_num, column_num = len(input_matrix), len(input_matrix[0])
for i in range(row_num):
for j in range(column_num):
input_matrix[i][j] = input_matrix[i][j]*1.0/column_sum[j]
return input_matrix
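#illustrative example (not called anywhere in this script): after normalize_column_one
#every column of the mobility matrix sums to 1, i.e., the in-flow of each zone is
#rescaled to 1 (toy matrix).
def _example_normalize_column_one():
    toy = np.array([[1.0, 2.0], [3.0, 2.0]])
    normalized = normalize_column_one(toy)
    assert np.allclose(np.sum(normalized, axis=0), 1.0)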
#function 2.2
#evaluate the trained model on validation or testing data.
def validate_test_process(trained_model, vali_test_data):
criterion = nn.MSELoss()
vali_test_y = [vali_test_data[i][1] for i in range(len(vali_test_data))]
y_real = torch.tensor(vali_test_y)
vali_test_x = [vali_test_data[i] for i in range(len(vali_test_data))]
vali_test_x = convertAdj(vali_test_x)
y_hat = trained_model.run_specGCN_lstm(vali_test_x)
loss = criterion(y_hat.float(), y_real.float()) ###Calculate the loss
return loss, y_hat, y_real
#function 2.3
#convert the mobility matrices in x_batch in the following way:
#normalize the flow between zones so that the in-flow of each zone is 1.
def convertAdj(x_batch):
#x_batch:(n_batch, 0/1, 2*i+1)
x_batch_new = copy.copy(x_batch)
n_batch = len(x_batch)
days = round(len(x_batch[0][0])/2)
for i in range(n_batch):
for j in range(days):
mobility_matrix = x_batch[i][0][2*j]
x_batch_new[i][0][2*j] = normalize_column_one(mobility_matrix) #20210818
return x_batch_new
#function 2.4
#a training epoch
def train_epoch_option(model, opt, criterion, trainX_c, trainY_c, batch_size):
model.train()
losses = []
batch_num = 0
for beg_i in range(0, len(trainX_c), batch_size):
batch_num += 1
if batch_num % 16 ==0:
print ("batch_num: ", batch_num, "total batch number: ", int(len(trainX_c)/batch_size))
x_batch = trainX_c[beg_i:beg_i+batch_size]
y_batch = torch.tensor(trainY_c[beg_i:beg_i+batch_size])
opt.zero_grad()
x_batch = convertAdj(x_batch) #conduct the column normalization
y_hat = model.run_specGCN_lstm(x_batch) ###Attention
loss = criterion(y_hat.float(), y_batch.float()) #MSE loss
#opt.zero_grad()
loss.backward()
opt.step()
losses.append(loss.data.numpy())
return sum(losses)/float(len(losses)), model
#function 2.5
#multiple training epochs
def train_process(train_data, lr, num_epochs, net, criterion, bs, vali_data, test_data):
opt = optim.Adam(net.parameters(), lr, betas = (0.9,0.999), weight_decay=0)
train_y = [train_data[i][1] for i in range(len(train_data))]
e_losses = list()
e_losses_vali = list()
e_losses_test = list()
time00 = time.time()
for e in range(num_epochs):
time1 = time.time()
print ("current epoch: ",e, "total epoch: ", num_epochs)
number_list = list(range(len(train_data)))
random.shuffle(number_list, random = r)
trainX_sample = [train_data[number_list[j]] for j in range(len(number_list))]
trainY_sample = [train_y[number_list[j]] for j in range(len(number_list))]
loss, net = train_epoch_option(net, opt, criterion, trainX_sample, trainY_sample, bs)
print ("train loss", loss*infection_normalize_ratio*infection_normalize_ratio)
e_losses.append(loss*infection_normalize_ratio*infection_normalize_ratio)
loss_vali, y_hat_vali, y_real_vali = validate_test_process(net, vali_data)
loss_test, y_hat_test, y_real_test = validate_test_process(net, test_data)
e_losses_vali.append(float(loss_vali)*infection_normalize_ratio*infection_normalize_ratio)
e_losses_test.append(float(loss_test)*infection_normalize_ratio*infection_normalize_ratio)
print ("validate loss", float(loss_vali)*infection_normalize_ratio*infection_normalize_ratio)
print ("test loss", float(loss_test)*infection_normalize_ratio*infection_normalize_ratio)
if e>=2 and (e+1)%10 ==0:
visual_loss(e_losses, e_losses_vali, e_losses_test)
visual_loss_train(e_losses)
time2 = time.time()
print ("running time for this epoch:", time2 - time1)
time01 = time.time()
print ("---------------------------------------------------------------")
print ("---------------------------------------------------------------")
#print ("total running time until now:", time01 - time00)
#print ("------------------------------------------------")
#print("specGCN_weight", net.specGCN.layer1.W)
#print("specGCN_weight_grad", net.specGCN.layer1.W.grad)
#print ("------------------------------------------------")
#print("memory decay matrix", net.v)
#print("memory decay matrix grad", net.v.grad)
#print ("------------------------------------------------")
#print ("lstm weight", net.lstm.all_weights[0][0])
#print ("lstm weight grad", net.lstm.all_weights[0][0].grad)
#print ("------------------------------------------------")
#print ("fc1.weight", net.fc1.weight)
#print ("fc1 weight grd", net.fc1.weight.grad)
#print ("---------------------------------------------------------------")
#print ("---------------------------------------------------------------")
return e_losses, net
#function 3.1
def read_data():
jcode23 = list(read_tokyo_23()["JCODE"]) #1.1 get the tokyo 23 zone shapefile
all_mobility = read_mobility_data(jcode23) #1.2 read the mobility data
all_text = read_text_data(jcode23) #1.3 read the text data
all_infection, all_infection_cum = read_infection_data(jcode23) #1.4 read the infection data
#smooth the data using 7-days average
window_size = WINDOW_SIZE #20210818
dateList = generate_dateList() #20210818
all_mobility = mob_inf_smooth(all_mobility, window_size, dateList) #20210818
all_infection = mob_inf_smooth(all_infection, window_size, dateList) #20210818
#smooth, user, min-max.
point_json = read_point_json() #20210821
all_text = normalize_text_user(all_text, point_json) #20210821
all_text = text_smooth(all_text, window_size, dateList) #20210818
all_text = min_max_text_data(all_text,jcode23) #20210820
x_days, y_days = X_day, Y_day
mobility_date_cut = sort_date_2(all_mobility, START_DATE, x_days, END_DATE, y_days)
all_x_y = ensemble(all_mobility, all_text, all_infection, x_days, y_days, mobility_date_cut)
train_original, validate_original, test_original, train_list, validation_list =\
split_data_3(all_x_y,train_ratio,validate_ratio)
zone_dict, text_dict = get_zone_text_to_idx(all_infection) #get zone_idx, text_idx
train_x_y = change_to_matrix(train_original, zone_dict, text_dict) #get train
print ("train_x_y_shape",len(train_x_y),"train_x_y_shape[0]",len(train_x_y[0]))
validate_x_y = change_to_matrix(validate_original, zone_dict, text_dict) #get validate
test_x_y = change_to_matrix(test_original, zone_dict, text_dict) #get test
print (len(train_x_y)) #number of training samples, e.g., 210
print (len(train_x_y[0][0])) #2*X_day = 42 (mobility and text matrices interleaved)
print (np.shape(train_x_y[0][0][0])) #(23,23) mobility matrix
print (np.shape(train_x_y[0][0][1])) #(23,8) text matrix (23 zones, 8 symptoms)
#print ("---------------------------------finish data reading and preprocessing------------------------------------")
return train_x_y, validate_x_y, test_x_y, all_mobility, all_infection, train_original, validate_original, test_original, train_list, validation_list
#function 3.2
#train the model
def model_train(train_x_y, vali_data, test_data):
#3.2.1 define the model
input_dim_1, hidden_dim_1, out_dim_1, hidden_dim_2 = len(train_x_y[0][0][1][1]),\
HIDDEN_DIM_1, OUT_DIM_1, HIDDEN_DIM_2
dropout_1, alpha_1, N = DROPOUT, ALPHA, len(train_x_y[0][0][1])
G_L_Model = SpecGCN_LSTM(X_day, Y_day, input_dim_1, hidden_dim_1, out_dim_1, hidden_dim_2, dropout_1,N) ###Attention
#3.2.2 train the model
num_epochs, batch_size, learning_rate = NUM_EPOCHS, BATCH_SIZE, LEARNING_RATE #model train
criterion = nn.MSELoss()
e_losses, trained_model = train_process(train_x_y, learning_rate, num_epochs, G_L_Model, criterion, batch_size,\
vali_data, test_data)
return e_losses, trained_model
#4.1
#read the data
train_x_y, validate_x_y, test_x_y, all_mobility, all_infection, \
train_original, validate_original, test_original, train_list, validation_list =\
read_data()
#train_x_y, validate_x_y, test_x_y = normalize(train_x_y, validate_x_y, test_x_y)
#train_x_y = train_x_y[0:30]
print (len(train_x_y))
print ("---------------------------------finish data preparation------------------------------------")
#sample output: train_x_y_shape 210, train_x_y_shape[0] 4; then 210, 42, (23, 23), (23, 8), 210,
#followed by "finish data preparation"
#4.2
#train the model
e_losses, trained_model = model_train(train_x_y, validate_x_y, test_x_y)
print ("---------------------------finish model training-------------------------")
#training log summary (100 epochs, 26 batches per epoch, roughly 11-15 s per epoch):
#train loss decreases from about 390 (epoch 0) to about 47 (epoch 99);
#validation loss falls from about 462 to a minimum near 45 around epoch 41, then rises to about 61;
#test loss decreases from about 421 to about 20.
running time for this epoch: 14.73263669013977 --------------------------------------------------------------- --------------------------------------------------------------- ---------------------------finish model training-------------------------
#4.3 check the sizes of the training, validation, and test sets
print (len(train_x_y))
print (len(validate_x_y))
print (len(test_x_y))
#4.3.1 model validation
validation_result, validate_hat, validate_real = validate_test_process(trained_model, validate_x_y)
print ("---------------------------------finish model validation------------------------------------")
print (len(validate_hat))
print (len(validate_real))
#4.3.2 model testing
test_result, test_hat, test_real = validate_test_process(trained_model, test_x_y)
print ("---------------------------------finish model testing------------------------------------")
print (len(test_real))
print (len(test_hat))
210
30
60
---------------------------------finish model validation------------------------------------
30
30
---------------------------------finish model testing------------------------------------
60
60
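#a quick sanity check (illustrative sketch, not part of the original pipeline): each element of
#validate_hat/validate_real and test_hat/test_real is expected to be a (Y_day, 23) tensor,
#one row per predicted day and one column per special ward.
assert len(validate_hat) == len(validate_real)
assert len(test_hat) == len(test_real)
#print (validate_hat[0].shape)   #expected: torch.Size([7, 23]) when Y_day = 7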
#5.1 evaluation metrics: RMSE, MAPE, MAE, RMSLE
def RMSELoss(yhat,y):
    return float(torch.sqrt(torch.mean((yhat-y)**2)))
def MAPELoss(yhat,y):
    #assumes all targets are positive; zero targets would make the percentage error undefined
    return float(torch.mean(torch.div(torch.abs(yhat-y), y)))
def MAELoss(yhat,y):
    return float(torch.mean(torch.abs(yhat-y)))
def RMSLELoss(yhat,y):
    log_yhat = torch.log(yhat+1)
    log_y = torch.log(y+1)
    return float(torch.sqrt(torch.mean((log_yhat-log_y)**2)))
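#a quick illustration of the four metrics on toy tensors (values are made up for the example):
yhat_demo = torch.tensor([1.0, 2.0, 4.0])
y_demo = torch.tensor([1.0, 2.0, 2.0])
print (RMSELoss(yhat_demo, y_demo))   #sqrt(mean([0,0,4])) = 1.1547...
print (MAELoss(yhat_demo, y_demo))    #mean([0,0,2]) = 0.6667...
print (MAPELoss(yhat_demo, y_demo))   #mean([0,0,1]) = 0.3333..., undefined if any target is 0
print (RMSLELoss(yhat_demo, y_demo))  #RMSE computed on log(1+x), less dominated by the largest errors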
#compute RMSE
rmse_validate = list()
rmse_test = list()
for i in range(len(validate_x_y)):
    rmse_validate.append(float(RMSELoss(validate_hat[i],validate_real[i])))
for i in range(len(test_x_y)):
    rmse_test.append(float(RMSELoss(test_hat[i],test_real[i])))
print ("rmse_validate mean", np.mean(rmse_validate))
print ("rmse_test mean", np.mean(rmse_test))
#compute MAE
mae_validate = list()
mae_test = list()
for i in range(len(validate_x_y)):
    mae_validate.append(float(MAELoss(validate_hat[i],validate_real[i])))
for i in range(len(test_x_y)):
    mae_test.append(float(MAELoss(test_hat[i],test_real[i])))
print ("mae_validate mean", np.mean(mae_validate))
print ("mae_test mean", np.mean(mae_test))
#rescale the errors back to daily case counts and show RMSE and MAE together
mae_validate, rmse_validate, mae_test, rmse_test =\
np.array(mae_validate)*infection_normalize_ratio, np.array(rmse_validate)*infection_normalize_ratio,\
np.array(mae_test)*infection_normalize_ratio, np.array(rmse_test)*infection_normalize_ratio
print ("-----------------------------------------")
print ("mae_validate mean", round(np.mean(mae_validate),3), " rmse_validate mean", round(np.mean(rmse_validate),3))
print ("mae_test mean", round(np.mean(mae_test),3), " rmse_test mean", round(np.mean(rmse_test),3))
print ("-----------------------------------------")
rmse_validate mean 0.06770370910790816
rmse_test mean 0.042385180168302825
mae_validate mean 0.0475774879994494
mae_test mean 0.03253371020757528
-----------------------------------------
mae_validate mean 4.758  rmse_validate mean 6.77
mae_test mean 3.253  rmse_test mean 4.239
-----------------------------------------
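#the metrics are first computed on normalized counts and then rescaled by infection_normalize_ratio (=100);
#e.g., the normalized validation RMSE of about 0.0677 corresponds to roughly 6.77 daily new cases.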
print (trained_model.v)
Parameter containing:
tensor([0.1562, 0.0911, 0.0436, 0.0056, 0.1415, 0.0726, 0.0675, 0.0287, 0.0473,
0.0558, 0.0025, 0.0007, 0.0453, 0.0369, 0.0132, 0.0481, 0.0597, 0.0851,
0.0103, 0.0040, 0.0015, 0.0133, 0.0035], requires_grad=True)
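#the 23 entries of trained_model.v line up with the 23 special wards, so each value presumably acts as a
#learned per-ward weight inside the attention/recovery part of SpecGCN_LSTM.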
print(validate_hat[0][Y_day-1])
print(torch.sum(validate_hat[0][Y_day-1]))
print(validate_real[0][Y_day-1])
print(torch.sum(validate_real[0][Y_day-1]))
tensor([0.0523, 0.1572, 0.3790, 0.5829, 0.1962, 0.2999, 0.2166, 0.4334, 0.3360,
0.3918, 0.7843, 0.9659, 0.4047, 0.4669, 0.5609, 0.2795, 0.3418, 0.2212,
0.5230, 0.6182, 0.7004, 0.3466, 0.6144], grad_fn=<SelectBackward>)
tensor(9.8733, grad_fn=<SumBackward0>)
tensor([0.0443, 0.1186, 0.2243, 0.4386, 0.1600, 0.2257, 0.1629, 0.3100, 0.2729,
0.1886, 0.6786, 0.7429, 0.2000, 0.3614, 0.5071, 0.2143, 0.2700, 0.1914,
0.5457, 0.4757, 0.5086, 0.4914, 0.4657], dtype=torch.float64)
tensor(7.7986, dtype=torch.float64)
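#for readability, the normalized vectors above can be rescaled back to case counts (an illustrative sketch):
print ((validate_hat[0][Y_day-1]*infection_normalize_ratio).detach().round())
print ((validate_real[0][Y_day-1]*infection_normalize_ratio).round())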
x = range(len(rmse_validate))
plt.figure(figsize=(8,2),dpi=300)
l1 = plt.plot(x, np.array(rmse_validate), 'ro-',linewidth=0.8, markersize=1.2, label='RMSE')
l2 = plt.plot(x, np.array(mae_validate), 'go-',linewidth=0.8, markersize=1.2, label='MAE')
plt.xlabel('Date from the first day of validation',fontsize=12)
plt.ylabel("RMSE/MAE daily new cases",fontsize=10)
my_y_ticks = np.arange(0,2100, 500)   #defined here but not applied to this plot
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.legend()
plt.grid()
plt.show()
x = range(len(mae_test))
plt.figure(figsize=(8,2),dpi=300)
l1 = plt.plot(x, np.array(rmse_test), 'ro-',linewidth=0.8, markersize=1.2, label='RMSE')
l2 = plt.plot(x, np.array(mae_test), 'go-',linewidth=0.5, markersize=1.2, label='MAE')
plt.xlabel('Date from the first day of test',fontsize=12)
plt.ylabel("RMSE/MAE Daily new cases",fontsize=10)
my_y_ticks = np.arange(0,2100, 500)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.legend()
plt.grid()
plt.show()
from scipy import stats
#validate
y_days = Y_day
validate_hat_sum = [float(torch.sum(validate_hat[i][y_days-1])) for i in range(len(validate_hat))]
validate_real_sum = [float(torch.sum(validate_real[i][y_days-1])) for i in range(len(validate_real))]
print ("the correlation between validation: ", stats.pearsonr(validate_hat_sum, validate_real_sum)[0])
#test
test_hat_sum = [float(torch.sum(test_hat[i][y_days-1])) for i in range(len(test_hat))]
test_real_sum = [float(torch.sum(test_real[i][y_days-1])) for i in range(len(test_real))]
print ("the correlation between test: ", stats.pearsonr(test_hat_sum, test_real_sum)[0])
#train
train_result, train_hat, train_real = validate_test_process(trained_model, train_x_y)
#note: the train correlation below is computed on the first predicted day (index 0),
#while the validation and test correlations above use day Y_day-1
train_hat_sum = [float(torch.sum(train_hat[i][0])) for i in range(len(train_hat))]
train_real_sum = [float(torch.sum(train_real[i][0])) for i in range(len(train_real))]
print ("correlation between predicted and real (train): ", stats.pearsonr(train_hat_sum, train_real_sum)[0])
correlation between predicted and real (validation):  0.9564129888260149
correlation between predicted and real (test):  0.8631332812393128
correlation between predicted and real (train):  0.9672063581661307
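#as a quick cross-check (illustrative), the same Pearson correlation can be recovered with numpy:
print (np.corrcoef(test_hat_sum, test_real_sum)[0,1])   #expected to be close to 0.863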
y1List = [np.sum(list(train_original[i+1][1][Y_day-1].values())) for i in range(len(train_original)-1)]
y2List = [np.sum(list(validate_original[i][1][Y_day-1].values())) for i in range(len(validate_original))]
y2List_hat = [float(torch.sum(validate_hat[i][Y_day-1])) for i in range(len(validate_hat))]
y3List = [np.sum(list(test_original[i][1][Y_day-1].values())) for i in range(len(test_original))]
y3List_hat = [float(torch.sum(test_hat[i][Y_day-1])) for i in range(len(test_hat))]
#x1 = np.array(range(len(y1List)))
#x2 = np.array([len(y1List)+j for j in range(len(y2List))])
x1 = train_list
x2 = validation_list
x3 = np.array([len(y1List)+len(y2List)+j for j in range(len(y3List))])
plt.figure(figsize=(8,2),dpi=300)
l1 = plt.plot(x1[0: len(y1List)], np.array(y1List)*infection_normalize_ratio, 'ro-',linewidth=0.8, markersize=2.0, label='train')
l2 = plt.plot(x2, np.array(y2List)*infection_normalize_ratio, 'go-',linewidth=0.8, markersize=2.0, label='validate')
l3 = plt.plot(x2, np.array(y2List_hat)*infection_normalize_ratio, 'g-',linewidth=2, markersize=0.1, label='validate_predict')
l4 = plt.plot(x3, np.array(y3List)*infection_normalize_ratio, 'bo-',linewidth=0.8, markersize=2, label='test')
l5 = plt.plot(x3, np.array(y3List_hat)*infection_normalize_ratio, 'b-',linewidth=2, markersize=0.1, label='test_predict')
#plt.xlabel('Date from the first day of 2020/4/1',fontsize=12)
plt.ylabel("Daily infection cases",fontsize=10)
my_y_ticks = np.arange(0,2100, 500)
my_x_ticks = list()
summary = 0
my_x_ticks.append(summary)
for i in range(5):
    summary += 60
    my_x_ticks.append(summary)
plt.xticks(my_x_ticks)
plt.yticks(my_y_ticks)
plt.xticks(fontsize=8)
plt.yticks(fontsize=12)
plt.title("SpectralGCN")
plt.legend()
plt.grid()
plt.show()
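#in the figure above, the red series is the observed training period, green marks the validation period
#(dots = observed, thick line = predicted), and blue marks the test period (dots = observed, thick line = predicted).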
def getPredictionPlot(k):
    #location k: plot real vs. predicted daily new cases for ward k over the test period
    x_k = [i for i in range(len(test_real))]
    real_k = [float(test_real[i][y_days-1][k]) for i in range(len(test_real))]
    predict_k = [float(test_hat[i][y_days-1][k]) for i in range(len(test_hat))]
    plt.figure(figsize=(4,2.5), dpi=300)
    l1 = plt.plot(x_k, np.array(real_k)*infection_normalize_ratio, 'ro-',linewidth=0.8, markersize=2.0, label='real',alpha = 0.8)
    l2 = plt.plot(x_k, np.array(predict_k)*infection_normalize_ratio, 'o-',color='black',linewidth=0.8, markersize=2.0, alpha = 0.8, label='predict')
    #plt.xlabel('Date from the first day of 2020/4/1',fontsize=12)
    #plt.ylabel("Daily infection cases",fontsize=10)
    my_y_ticks = np.arange(0,100,35)
    my_x_ticks = list()
    summary = 0
    my_x_ticks.append(summary)
    for i in range(3):
        summary += 20
        my_x_ticks.append(summary)
    plt.xticks(my_x_ticks)
    plt.yticks(my_y_ticks)
    plt.xticks(fontsize = 16)
    plt.yticks(fontsize = 16)
    plt.title("Urban district "+str(k+1))
    #plt.legend(fontsize=20)
    plt.legend(loc=2,fontsize=15)
    #plt.grid()
    os.makedirs('peak4', exist_ok=True)   #create the output folder if it does not exist yet
    plt.savefig('peak4/' + str(k+1) + '.pdf',bbox_inches = 'tight')
    plt.show()
for i in range(23):
    getPredictionPlot(i)
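#a small companion check (illustrative sketch, not part of the original script): rank the 23 wards by their
#test-period MAE in daily cases, using the same test_hat/test_real structure as in the plots above.
district_mae = list()
for k in range(23):
    errs = [abs(float(test_hat[i][y_days-1][k]) - float(test_real[i][y_days-1][k])) for i in range(len(test_hat))]
    district_mae.append(np.mean(errs)*infection_normalize_ratio)
print (sorted(enumerate(district_mae), key=lambda kv: kv[1])[:5])   #five best-predicted wards as (index, MAE)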